{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 编写定制 Pass\n",
        "\n",
        "\n",
        "**原作者**: [Jian Weng](https://were.github.io)\n",
        "\n",
        "TVM 是抽象出机器学习加速器异质性（heterogenity）的框架。有时用户可能希望定制一些分析和 IR 变换，使 TVM 适应他们自己的专用硬件。本教程帮助用户在 TVM 中编写定制的pass。\n",
        "\n",
        "## 前提条件\n",
        "\n",
        "在阅读本教程开始之前，假设读者已经很好地了解了以下主题：\n",
        "\n",
        "- 在 TVM 中编写算法并对其进行调度。否则，请参阅示例教程，如 {ref}`opt-gemm`。\n",
        "- HalideIR 的基本结构。否则，请参见 ``HalideIR/src/ir/IR.h`` 来了解 IR 节点定义了哪些属性。\n",
        "- 访问者设计模式（Visitor design pattern）。否则，请查看 [Python AST 模块](https://docs.python.org/3/library/ast.html)，查看 AST visitor 是如何实现的。\n",
        "- 如何将 Schedule 降格（lower）为 IRModule class 或 LLVM module。否则，请查看 ``python/tvm/build_module.py`` 以获得一些基础知识。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm import te"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "编写非常简单的向量加法，并使用默认的调度来构建它。然后，使用定制的 lower pass 直接操作 IR，而不是使用调度原语（primitives.）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        a: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        b: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        c: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>),\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>),\n",
              "            }\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">128</span>):\n",
              "            c_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>c<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            a_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>a<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            b_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>b<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            c_1[i] <span style=\"color: #AA22FF; font-weight: bold\">=</span> a_1[i] <span style=\"color: #AA22FF; font-weight: bold\">+</span> b_1[i]\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "n = tvm.tir.const(128, \"int32\")\n",
        "a = te.placeholder((n,), name=\"a\")\n",
        "b = te.placeholder((n,), name=\"b\")\n",
        "c = te.compute((n,), lambda i: a[i] + b[i], name=\"c\")\n",
        "\n",
        "sch = te.create_schedule(c.op)\n",
        "ir = tvm.lower(sch, [a, b, c])\n",
        "ir.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 编写 Pass\n",
        "\n",
        "本质上，“IR 变换 pass” 是将语句映射到新语句的函数。因此，下面定义了一个向量化函数，并逐步实现它。\n",
        "\n",
        "TVM 已经为用户提供了两个类来分析和变换 IR。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### IR Visitor\n",
        "\n",
        "可以使用 ``tvm.tir.stmt_functor.post_order_visit(stmt, func)`` 从 Halide IR 收集信息。``func`` 是回调函数。该函数将在退出当前 IR 节点之前调用，即后序访问（post-order visit）。然后利用副作用来存储 IR 访问的结果，因为 ``func`` 的返回值会被忽略。\n",
        "\n",
        "```{note}\n",
        ":class: alert alert-info\n",
        "\n",
        "你必须使用一些数组来存储 IR 访问的结果。甚至值也是 single 变量。这主要是由于 Python-C 运行时中的约束。每次递归都会刷新变量值，但保留数组值。\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "loops = []\n",
        "\n",
        "def find_width8(op):\n",
        "    \"\"\"找出所有范围能被 8 除的 'tir.For' 节点。\"\"\"\n",
        "    if isinstance(op, tvm.tir.For):\n",
        "        if isinstance(op.extent, tvm.tir.IntImm):\n",
        "            if op.extent.value % 8 == 0:\n",
        "                loops.append(op)"
      ]
    },
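    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The post-order callback pattern can be tried without TVM. The following is a minimal sketch that uses Python's standard ``ast`` module as a stand-in for TVM IR; the ``post_order_visit`` helper and the ``collect_for`` callback defined here are hypothetical illustrations, not TVM APIs. Results are stored through a side effect on a list, and the inner loop is recorded before the outer one because children are visited first.\n",
        "\n",
        "```python\n",
        "import ast\n",
        "\n",
        "\n",
        "def post_order_visit(node, func):\n",
        "    # Visit all children first, then call func on the node itself.\n",
        "    for child in ast.iter_child_nodes(node):\n",
        "        post_order_visit(child, func)\n",
        "    func(node)  # the return value of func is ignored\n",
        "\n",
        "\n",
        "found = []  # results are stored through a side effect on this list\n",
        "\n",
        "\n",
        "def collect_for(node):\n",
        "    # Record the loop variable name of every `for` statement.\n",
        "    if isinstance(node, ast.For) and isinstance(node.target, ast.Name):\n",
        "        found.append(node.target.id)\n",
        "\n",
        "\n",
        "source = '''\n",
        "for i in range(128):\n",
        "    for j in range(8):\n",
        "        pass\n",
        "'''\n",
        "post_order_visit(ast.parse(source), collect_for)\n",
        "print(found)\n",
        "```"
      ]
    },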
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### IR 变换\n",
        "\n",
        "变换（transformation）接口与 visitor 接口略有不同。在 visitor 中只有 post-order 回调，但是  transformation visitor 同时支持 pre-order 和 post-order 回调。如果您想保留原始 IR 节点，只需返回 None。如果您想将当前节点更改为某个节点，请使用 TVM IR maker 接口来构建它并返回此值。\n",
        "\n",
        "```{note}\n",
        ":class: alert alert-info\n",
        "\n",
        "如果 pre-order 函数被调用并返回非 None 的值，则 post-order 函数将被跳过。\n",
        "```"
      ]
    },
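    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The callback contract described above can be illustrated without TVM. The following is a simplified sketch on nested tuples, assuming a toy tree in which a node is ``(op, child, ...)`` and anything else is a leaf. ``toy_transform``, ``fold_mul`` and ``freeze_add`` are hypothetical names introduced for illustration; this is not how TVM implements ``ir_transform``, which can also restrict the callbacks to given node types.\n",
        "\n",
        "```python\n",
        "def toy_transform(node, preorder, postorder):\n",
        "    # A node is a tuple (op, child, ...); anything else is a leaf.\n",
        "    if preorder is not None:\n",
        "        replaced = preorder(node)\n",
        "        if replaced is not None:\n",
        "            return replaced  # children and the post-order callback are skipped\n",
        "    if isinstance(node, tuple):\n",
        "        children = tuple(toy_transform(c, preorder, postorder) for c in node[1:])\n",
        "        node = (node[0],) + children\n",
        "    if postorder is not None:\n",
        "        replaced = postorder(node)\n",
        "        if replaced is not None:\n",
        "            return replaced\n",
        "    return node  # None from both callbacks keeps the (rebuilt) node\n",
        "\n",
        "\n",
        "def fold_mul(node):\n",
        "    # Post-order callback: replace ('mul', a, b) with the product a * b.\n",
        "    if isinstance(node, tuple) and node[0] == 'mul':\n",
        "        return node[1] * node[2]\n",
        "    return None\n",
        "\n",
        "\n",
        "def freeze_add(node):\n",
        "    # Pre-order callback: return 'add' nodes unchanged, cutting off the walk.\n",
        "    if isinstance(node, tuple) and node[0] == 'add':\n",
        "        return node\n",
        "    return None\n",
        "\n",
        "\n",
        "tree = ('add', ('mul', 2, 3), 4)\n",
        "folded = toy_transform(tree, None, fold_mul)\n",
        "frozen = toy_transform(tree, freeze_add, fold_mul)\n",
        "print(folded)\n",
        "print(frozen)\n",
        "```\n",
        "\n",
        "With only the post-order callback, the inner ``mul`` node is folded. With ``freeze_add`` as the pre-order callback, the root is returned as is and ``fold_mul`` never runs."
      ]
    },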
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def vectorize8(op):\n",
        "    \"\"\"Split can vectorize the loops found in `find_width8`.\"\"\"\n",
        "    if op in loops:\n",
        "        extent = op.extent.value\n",
        "        name = op.loop_var.name\n",
        "        lo, li = te.var(name + \".outer\"), te.var(name + \".inner\")\n",
        "        body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})\n",
        "        body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)\n",
        "        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)\n",
        "        return body\n",
        "    return None\n",
        "\n",
        "\n",
        "@tvm.tir.transform.prim_func_pass(opt_level=0)\n",
        "def vectorize(f, mod, ctx):\n",
        "    global loops\n",
        "    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)\n",
        "    if not loops:\n",
        "        return f\n",
        "    # 最后一个 list 参数表示要转换的节点类型。\n",
        "    # 因此，在这种情况下，只有 `For` 节点会调用 `vectorize8`\n",
        "    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, [\"tir.For\"]))"
      ]
    },
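    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The substitution ``lo * 8 + li`` used in ``vectorize8`` is only valid if the split loops cover every original iteration exactly once. The following is a small plain-Python check of that index arithmetic, assuming the extent of 128 from the vector addition above; the names ``extent``, ``factor`` and ``rebuilt`` are introduced here for illustration only.\n",
        "\n",
        "```python\n",
        "extent = 128  # loop extent of the vector addition example\n",
        "factor = 8    # inner loop length used by vectorize8\n",
        "\n",
        "# Enumerate (outer, inner) pairs in loop order and rebuild the original index.\n",
        "rebuilt = [lo * factor + li\n",
        "           for lo in range(extent // factor)\n",
        "           for li in range(factor)]\n",
        "\n",
        "# Every original iteration appears exactly once, in the original order.\n",
        "assert rebuilt == list(range(extent))\n",
        "```\n",
        "\n",
        "This is why ``find_width8`` only accepts extents divisible by 8: otherwise ``extent // 8`` outer iterations would drop the remainder."
      ]
    },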
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Glue 到 lower pass\n",
        "\n",
        "到目前为止，已经完成了这个 IR 变换过程。接下来需要做的是将这个 pass 粘合到 TVM 的 lower pass 上。\n",
        "\n",
        "在本例中，通过向 ``tir.add_lower_pass`` 提供元组参数列表，将上面编写的 pass 注入到 TVM 标准 lower pass 中。\"Tuple\" 表明 lower 的不同阶段。在 TVM 中，有四个 lower 阶段，每个阶段(phase)完成后将调用用户自定义的阶段。\n",
        "\n",
        "```{note}\n",
        ":class: alert alert-info\n",
        "以下是每个阶段所做的基本变换：\n",
        "\n",
        "- 阶段 0：生成 raw IR 和循环级别（loop levels）。\n",
        "- 阶段 1：对 array storage 进行扁平化（flatten）。\n",
        "- 阶段 2：变换循环（transforms loops）：如 unroll、vectorization 和 thread-binding。\n",
        "- 阶段 3：做一些清理工作。\n",
        "```\n",
        "\n",
        "因此，将这个变换过程放置在阶段 1 之后是一个很好的地方。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        a: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        b: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        c: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>),\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>),\n",
              "            }\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i_outer <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "            cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> i_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">8</span>\n",
              "            c_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>c<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            a_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>a<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            b_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">128</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>b<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "            c_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">8</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                a_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">8</span>] <span style=\"color: #AA22FF; font-weight: bold\">+</span> b_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">8</span>]\n",
              "            )\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "with tvm.transform.PassContext(config={\"tir.add_lower_pass\":\n",
        "                                       [(1, vectorize)]}):\n",
        "    tvm.lower(sch, [a, b, c]).show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[for i in range(128):\n",
              "     c = T.Buffer((128,))\n",
              "     a = T.Buffer((128,))\n",
              "     b = T.Buffer((128,))\n",
              "     c[i] = a[i] + b[i]]"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "loops"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 快速视图\n",
        "\n",
        "本教程提供了编写自定义 IR 变换 pass 的快速视图：\n",
        "\n",
        "- 使用 ``tvm.tir.stmt_functor.post_order_visit`` 收集每个 IR 节点的信息。\n",
        "- 使用 ``tvm.tir.stmt_functor.ir_transform`` 变换 IR 节点。\n",
        "- 包装上面的两个，写出 IR-transformation 函数。\n",
        "- 使用 ``tvm.transform.PassContext`` 将该函数放入 TVM lowering pass"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('py38': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    },
    "vscode": {
      "interpreter": {
        "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
